# -*- coding=utf-8 -*-
import numpy as np
from numpy.random import permutation
from fine_tune_pretraind_keras_vgg_jiao_dong import data_root, load_train, load_test, cache_data, \
    restore_data, read_and_normalize_test_data
import os
from keras.utils import np_utils

# X_train, y_train, driver_id, unique_drivers = load_train(64, 64, 1)
# print np.asarray(X_train).shape

use_cache = 1


def read_and_normalize_and_shuffle_train_data(img_rows, img_cols, color_type=1):
    """Load (or restore from cache) the training set, normalize and shuffle it.

    Args:
        img_rows: target image height in pixels.
        img_cols: target image width in pixels.
        color_type: number of channels — 1 for grayscale, 3 for BGR color.

    Returns:
        Tuple (train_data, train_target, driver_id, unique_drivers) where
        train_data is float32 shaped (N, color_type, img_rows, img_cols)
        and train_target is one-hot encoded over the 10 classes.
    """
    cache_path = os.path.join(data_root, 'cache', 'train_r_' + str(img_rows) +
                              '_c_' + str(img_cols) + '_t_' +
                              str(color_type) + '.dat')
    if not os.path.isfile(cache_path) or use_cache == 0:
        train_data, train_target, driver_id, unique_drivers = \
            load_train(img_rows, img_cols, color_type)
        cache_data((train_data, train_target, driver_id, unique_drivers),
                   cache_path)
    else:
        print('Restore train from cache!')
        (train_data, train_target, driver_id, unique_drivers) = \
            restore_data(cache_path)

    train_data = np.asarray(train_data, dtype=np.uint8)
    train_target = np.asarray(train_target, dtype=np.uint8)

    if color_type == 1:
        # Add an explicit single-channel axis: (N, 1, rows, cols).
        train_data = train_data.reshape(train_data.shape[0], color_type, img_rows, img_cols)
    else:
        # Move channels first: (N, rows, cols, C) -> (N, C, rows, cols).
        train_data = train_data.transpose((0, 3, 1, 2))

    train_target = np_utils.to_categorical(train_target, 10)
    train_data = train_data.astype('float32')
    # ImageNet BGR per-channel means (VGG preprocessing convention).
    mean_pixel = [103.939, 116.779, 123.68]
    # BUG FIX: the original always looped over 3 channels, which raises an
    # IndexError for grayscale data (only 1 channel after the reshape above).
    # Subtract the mean only for the channels that actually exist.
    for c in range(min(3, train_data.shape[1])):
        train_data[:, c, :, :] = train_data[:, c, :, :] - mean_pixel[c]
    # train_data /= 255
    perm = permutation(len(train_target))
    train_data = train_data[perm]
    train_target = train_target[perm]
    print('Train shape:', train_data.shape)
    print(train_data.shape[0], 'train samples')
    return train_data, train_target, driver_id, unique_drivers


# train_data, train_target, driver_id, unique_drivers=read_and_normalize_and_shuffle_train_data(224,224,3)
# print train_target
# print('Start testing............')
# test_data, test_id = read_and_normalize_test_data(224, 224,3)

#### test
# X_test,  test_id = load_test(img_rows, img_cols, color_type_global)
# print np.asarray(X_test, np.float32).shape
# X_train, y_train, driver_id, unique_drivers = load_train(img_rows, img_cols, color_type_global)
# print np.asarray(X_train, np.float32).shape
# print unique_drivers
# train_data, train_target, driver_id, unique_drivers = \
#    read_and_normalize_and_shuffle_train_data(img_rows, img_cols, color_type_global)

import glob
import math


def create_source_file():
    """Write Caffe-style image list files under ``data_root``.

    train.txt gets one "<image path> <class index>" line per training image
    (classes c0..c9); test.txt gets one image path per line. Uses ``with``
    so the files are closed even if globbing fails mid-way.
    """
    with open(data_root + '/train.txt', mode='w') as f_train:
        for label in range(10):
            print('Load folder c{}'.format(label))
            train_path = os.path.join(data_root, 'imgs', 'train',
                                      'c' + str(label), '*.jpg')
            for fl in glob.glob(train_path):
                f_train.write(fl + ' ' + str(label) + '\n')

    with open(data_root + '/test.txt', mode='w') as f_test:
        test_path = os.path.join(data_root, 'imgs', 'test', '*.jpg')
        for fl in glob.glob(test_path):
            f_test.write(fl + '\n')


# create_source_file()


# import caffe
import numpy as np
import sys

# blob = caffe.proto.caffe_pb2.BlobProto()
# data = open( file1 , 'rb' ).read()
# print data
"""
# Import caffe
import os
import sys
sys.path.append('/home/lab/wjSun/caffe-master/python')
import caffe

import numpy as np

# Print arrays in full; without this, long parameter arrays are
# abbreviated with an ellipsis ("...").
np.set_printoptions(threshold='nan')

# Mean files
#MEAN_FILE = 'mean.binaryproto'
VGG_MEAN='/media/lab/labdisk/home/lab1314/xryu/pretrined_model/caffe/vgg/VGG_mean.binaryproto'
RESNET_MEAN='/media/lab/labdisk/home/lab1314/xryu/pretrined_model/caffe/ResNet/ResNet_mean.binaryproto'
# Output files for the dumped mean values
means_txt = 'VGGmeans.txt'
resnet_means_txt = 'RESmean.txt'
mf = open(resnet_means_txt, 'w')

# Read the mean file into a blob
mean_blob = caffe.proto.caffe_pb2.BlobProto()
mean_blob.ParseFromString(open(RESNET_MEAN, 'rb').read())

# Convert the mean blob to a numpy array
mean_npy = caffe.io.blobproto_to_array(mean_blob)
# The mean is a multi-dimensional array; reshape it to a single
# column to simplify writing the values out.
mean_npy.shape = (-1, 1)
for m in mean_npy:
    # Write one value
    mf.write('%ff, ' % m)

mf.close
"""
if __name__ == '__main__':
    # Classify the test set with a fine-tuned GoogLeNet Caffe model and
    # write a Kaggle submission file.
    from kfold_cv_caffe import classify_extracted_feature_cls3_fc_ddd, classify, create_submission

    snapshot_root = '/media/wjsun/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/fine-tune-caffe/snapshot/'
    TEST_IMG_ROOT = '/media/wjsun/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/imgs/test'
    sample_img = '/media/wjsun/dell/Rui/data/Kaggle/DistractedDriverDetection/imgs/c4_p072_sample'
    mean_sub_sampe = '/media/wjsun/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/imgs/test_mean_substracted_224/bgr'
    # googlenet
    googlenet_root = '/home/dell/Rui/kaggle_ddd_finetune_caffe/fine_tune_googlenet/'
    model_def = googlenet_root + '5foldcv/1crop/googlenet_test.prototxt'
    imagenet_mean_path = '/media/wjsun/delldisk/dell/Rui/pretrined_model/caffe/googlenet/imagenet_mean.binaryproto'
    info_string = 'LogisticRegression_extracted_cls3_fc_ddd_feature'

    # model_weights=snapshot_root+'5foldcv/1crop/googlenet/fix_0/_iter_150.caffemodel'
    model_weights = '/media/wjsun/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection/fine-tune-caffe/snapshot/' \
                    '5foldcv/1crop/googlenet/sigmoid_2/_iter_3000.caffemodel'

    preds, test_id, feats = classify_extracted_feature_cls3_fc_ddd(caffemodel=model_weights,
                                                                   deploy_file=model_def,
                                                                   test_image_files=TEST_IMG_ROOT,
                                                                   mean_file=imagenet_mean_path,
                                                                   labels_file=None,
                                                                   batch_size=1,
                                                                   use_gpu=True,
                                                                   DEBUG=False,
                                                                   EXTRACT=False,
                                                                   feature='cls3_fc_ddd'
                                                                   )
    # print feats[0]
    # BUG FIX: these two were Python 2 `print` statements while the rest of
    # the file uses print() calls; the statement form is a SyntaxError under
    # Python 3 and broke the whole module.
    print('creating submission file...')
    create_submission(preds, test_id, info_string)
    print('done')
